#include "cacla.h"



// Builds a CACLA learner around an existing actor network, critic network and
// simulator. The three pointers are stored as passed in; nothing is copied here.
//
// Also opens the log files ./m2/output.txt and ./m2/state.txt for writing.
// If either cannot be opened, a warning is printed and the corresponding
// stream pointer is left as NULL.
cacla::cacla(mlp *actor, mlp *critic, simulator *sim): actor(actor), critic(critic), sim(sim)
{
	gamma = 0.0;          // weight of the next-state critic value in the TD error (see learnMove)
	stepsize = 0.300025;  // factor applied to the raw actor output; original note: 0.25

	// fopen_s reports failure through its return code; make the NULL explicit so
	// later code can test the pointers reliably.
	if(fopen_s(&fout, "./m2/output.txt", "w") != 0){
		fout = NULL;
		printf("WARNING: could not open ./m2/output.txt for writing\n");
	}
	if(fopen_s(&fst, "./m2/state.txt", "w") != 0){
		fst = NULL;
		printf("WARNING: could not open ./m2/state.txt for writing\n");
	}
}

// Closes the two log streams opened by the constructor. The actor, critic and
// simulator pointers are not released here.
cacla::~cacla(void){
	// fclose on a null stream is undefined, and the opens in the constructor
	// can fail (e.g. when ./m2/ does not exist), so guard both calls.
	if(fout != NULL){
		fclose(fout);
		fout = NULL;
	}
	if(fst != NULL){
		fclose(fst);
		fst = NULL;
	}
}

// Returns a newly allocated vector holding `action` with an independent random
// offset added to every component. Each offset is derived from one rand() call
// and lies between -maxdiff and +maxdiff. `action` itself is left untouched;
// the result must be released with gsl_vector_free by the caller.
gsl_vector *cacla::modifyAction(gsl_vector *action, double maxdiff){
	const int count = (int)action->size;
	gsl_vector *noisy = gsl_vector_calloc(action->size);

	int idx = 0;
	while(idx < count){
		// Map rand() onto [0, 2*maxdiff], then shift down so the range is centred on zero.
		double offset = 2 * maxdiff * rand() / (float)RAND_MAX;
		offset -= maxdiff;

		const double original = gsl_vector_get(action, idx);
		gsl_vector_set(noisy, idx, original + offset);
		++idx;
	}

	return noisy;
}

// Returns a newly allocated vector of `size` components, each derived from one
// rand() call and lying between -maxdiff and +maxdiff. The result must be
// released with gsl_vector_free by the caller.
gsl_vector *cacla::randomAction(unsigned int size, double maxdiff){
	gsl_vector *sample = gsl_vector_alloc(size);

	unsigned int idx = 0;
	while(idx < size){
		// Map rand() onto [0, 2*maxdiff], then shift down so the range is centred on zero.
		double value = 2 * maxdiff * rand() / (float)RAND_MAX;
		value -= maxdiff;
		gsl_vector_set(sample, idx, value);
		++idx;
	}

	return sample;
}
// Performs one CACLA learning step with the right arm and returns the reward
// measured after the arm has moved.
//
// Steps:
//   1. Build the state s_t from the scaled arm vector, targetv, actionv, the
//      hand-touch bit and a bias term.
//   2. Evaluate critic and actor on s_t; the actor output is scaled by `stepsize`.
//   3. When rand() > 0.60 * RAND_MAX the actor output is zeroed and replaced by
//      random noise of half-width (hand-to-target distance - 0.10); otherwise
//      the actor output is used as-is (noise of width 0).
//   4. Apply the clipped step in the simulator, build s_{t+1}, evaluate the critic.
//   5. Train the critic on s_t towards V(s_t) + delta, where
//      delta = r_{t+1} + gamma * V(s_{t+1}) - V(s_t).
//   6. If V(s_{t+1}) > V(s_t), train the actor on s_t towards the executed step
//      (rescaled by 1 / stepsize).
//   7. Append the training pairs to f / fdes / fActor / fActorDes and print a
//      one-line summary.
//
// targetv, actionv: vectors supplied by the caller (learnActions sets exactly
//   one element of each to 1); they are read only and not freed here.
//
// The streams f, fdes, fActor and fActorDes are not opened by this function;
// learnActions opens them before calling it.
double cacla::learnMove(gsl_vector *targetv, gsl_vector *actionv){

	gsl_vector *arm = sim->arm_get(PART_RIGHT_ARM, 4);
	sim->scale_arm(arm);

	// Combine into the network input for time t.
	gsl_vector *state_t = combineVectors(arm, targetv);
	state_t = combineVectorsFreeFirst(state_t, actionv);
	state_t = addBit(state_t, sim->rhand_touch());
	state_t = addBias(state_t); // original note: adds the bias and frees the previous vector

	// Reward before the move; only used for diagnostics below.
	double reward_t = getReward(targetv, getOneHotActive(actionv));

	gsl_vector *value_t_s = critic->runNetwork(state_t);

	gsl_vector *action_t = actor->runNetwork(state_t);
	gsl_vector_scale(action_t, stepsize);
	gsl_vector *action_t_mod;

	// Exploration width follows the current hand-to-target distance (original note: 0.25 - 0.5).
	double randomScale = distance_euclidean(sim->world_get_rhand(), sim->getObjectPosition(targetv));

	if(rand() > RAND_MAX * 0.60){
		// Explore: ignore the actor and take a purely random step.
		gsl_vector_set_all(action_t, 0.0);
		action_t_mod = modifyAction(action_t, randomScale - 0.10);
	}else{
		// Exploit: copy of the actor's step (zero-width noise).
		action_t_mod = modifyAction(action_t, 0.0);
	}

	// nextArmPosition also rewrites action_t_mod to the step that survived clipping.
	gsl_vector *arm_t_next = nextArmPosition(arm, action_t_mod);

	sim->scale_back_arm(arm_t_next);
	sim->arm_set(arm_t_next, PART_RIGHT_ARM);

	// Reward after the move.
	double reward_tn = getReward(targetv, getOneHotActive(actionv));

	// Re-read the arm so the next state reflects what the simulator actually did.
	gsl_vector_free(arm);
	arm = sim->arm_get(PART_RIGHT_ARM, 4);
	sim->scale_arm(arm);

	gsl_vector *state_tn = combineVectors(arm, targetv);
	state_tn = combineVectorsFreeFirst(state_tn, actionv);
	state_tn = addBit(state_tn, sim->rhand_touch());
	state_tn = addBias(state_tn);

	gsl_vector *value_t_sn = critic->runNetwork(state_tn);

	if(reward_t > 1.0 || reward_t < - 1.0){
		printf("WARNING: Reward out of range \n");
	}

	// Temporal-difference error and the resulting critic target.
	double delta = reward_tn + gamma * gsl_vector_get(value_t_sn, 0) - gsl_vector_get(value_t_s, 0);
	gsl_vector *critic_desired = gsl_vector_alloc(1);
	gsl_vector_set(critic_desired, 0, gsl_vector_get(value_t_s, 0) + 1.0 * delta);

	critic->trainNetwork(state_t, critic_desired, 0.01); // original note: 0.01 is optimal

	gsl_vector *value_tn_s = critic->runNetwork(state_t);

	// Critic error on s_t measured with the pre-update value.
	double d_t_s = gsl_vector_get(critic_desired, 0) - gsl_vector_get(value_t_s, 0);

	// The result is not used, but the call is kept in case runNetwork updates
	// state inside the critic; the returned vector is released below (it was
	// previously leaked on every step).
	gsl_vector *value_tn_sn = critic->runNetwork(state_tn);

	// CACLA rule: only reinforce the executed step when the critic rated the
	// resulting state higher than the starting state.
	if(value_t_sn->data[0] > value_t_s->data[0]){
		gsl_vector_scale(action_t_mod, 1.0 / stepsize); // back to the actor's output scale

		actor->trainNetwork(state_t, action_t_mod, 0.0005); // original note: 0.001 is probably the optimum
		printf("L| ");
	}else{
		printf("_| ");
	}

	print_vector_to_file(state_t, f);
	print_vector_to_file(critic_desired, fdes);

	print_vector_to_file(state_t, fActor);
	print_vector_to_file(action_t_mod, fActorDes);

	printf("C_error: %+f | C_reward: %+f  | r_t: %+f | r_tn: %+f\n", d_t_s * d_t_s / 2.0, value_t_s->data[0], reward_t, reward_tn);

	gsl_vector_free(arm);
	gsl_vector_free(state_t);
	gsl_vector_free(value_t_s);
	gsl_vector_free(action_t);
	gsl_vector_free(action_t_mod);
	gsl_vector_free(arm_t_next);
	gsl_vector_free(state_tn);
	gsl_vector_free(value_t_sn);
	gsl_vector_free(critic_desired);
	gsl_vector_free(value_tn_s);
	gsl_vector_free(value_tn_sn);
	return reward_tn;
}

// Executes one step of the actor's policy without any learning: reads the
// right arm, asks the actor for a step, applies the clipped step in the
// simulator and returns the reward measured afterwards.
//
// targetv, actionv: vectors supplied by the caller; read only, not freed here.
double cacla::makeMove(gsl_vector *targetv, gsl_vector *actionv){

	// Current right-arm vector, rescaled by the simulator.
	gsl_vector *joints = sim->arm_get(PART_RIGHT_ARM, 4);
	sim->scale_arm(joints);

	// Network input: arm + target + action, then the touch bit, then the bias.
	gsl_vector *input = combineVectorsFreeFirst(combineVectors(joints, targetv), actionv);
	input = addBit(input, sim->rhand_touch());
	input = addBias(input); // original note: adds the bias and frees the previous vector

	// Actor output scaled by the step size.
	gsl_vector *step = actor->runNetwork(input);
	gsl_vector_scale(step, stepsize);

	// Clip the resulting arm vector and hand it to the simulator.
	gsl_vector *nextJoints = nextArmPosition(joints, step);
	sim->scale_back_arm(nextJoints);
	sim->arm_set(nextJoints, PART_RIGHT_ARM);

	gsl_vector_free(nextJoints);
	gsl_vector_free(step);
	gsl_vector_free(input);
	gsl_vector_free(joints);

	return getReward(targetv, getOneHotActive(actionv));
}

// Computes state + diff with every component clamped to the range [-1.0, 1.0].
//
// state: current arm vector; not modified.
// diff:  requested step; overwritten with the step that remains after
//        clamping (clamped result minus state).
// Returns a newly allocated vector; the caller releases it with gsl_vector_free.
gsl_vector *cacla::nextArmPosition(gsl_vector *state, gsl_vector *diff){
	gsl_vector *clamped = gsl_vector_calloc(state->size);
	gsl_vector_memcpy(clamped, state);
	gsl_vector_add(clamped, diff);

	// Clamp each component to [-1.0, 1.0].
	for(unsigned int k = 0; k < clamped->size; ++k){
		const double raw = gsl_vector_get(clamped, k);
		gsl_vector_set(clamped, k, max(min(raw, 1.0), -1.0));
	}

	// Report back the step that was actually applied after clamping.
	gsl_vector_memcpy(diff, clamped);
	gsl_vector_sub(diff, state);

	return clamped;
}

/*
double cacla::getReward(position targetPos, position hand_position_t){
	return -distance_euclidean(targetPos, hand_position_t) * 1.0 + 1.0;
}
*/
// Returns the reward for the given action type, clamped to [-1.0, 1.0].
//
// targetv: target vector forwarded to the action-specific reward function.
// action:  one of ACTION_POINT, ACTION_TOUCH or ACTION_PUSH.
// Returns -1.0 and prints an error for any other action value.
double cacla::getReward(gsl_vector *targetv, int action){
	// Evaluate the reward exactly once. If min/max are function-like macros,
	// nesting the call inside them would re-run the whole reward function
	// (simulator queries and allocations included) for every macro expansion.
	double raw;
	switch(action){
		case ACTION_POINT:
			raw = getRewardPoint(targetv);
			break;
		case ACTION_TOUCH:
			raw = getRewardTouch(targetv);
			break;
		case ACTION_PUSH:
			raw = getRewardPush(targetv);
			break;
		default:
			printf("ERROR: Illegal action");
			return -1.0;
	}
	return min(1.0, max(-1.0, raw));
}
// Reward for the "touch" action: minus twice the distance between the hand and
// the target object, plus a contact term that is applied only while the hand
// reports a touch:
//   -0.3  when isTargetNearest holds and the target has not been moved,
//   -0.5  in every other touching case (target moved, or isTargetNearest false).
double cacla::getRewardTouch(gsl_vector *targetv){
	position objectPos = sim->getObjectPosition(targetv);
	position handPos = sim->world_get_rhand();

	double contactTerm = 0;
	if(sim->rhand_touch()){
		// movedTarget is only queried when isTargetNearest holds.
		const bool cleanTouch = sim->isTargetNearest(targetv) && !sim->movedTarget(targetv);
		if(cleanTouch){
			contactTerm = -0.3; // touching the target without having moved it
		}else{
			contactTerm = -0.5; // target was moved, or the contact is not with the target
		}
	}

	return contactTerm - distance_euclidean(objectPos, handPos) * 2.0 ;
}
// Reward for the "push" action: minus twice the distance between the hand and
// a goal point placed 0.02 above (in y) the target object's start position.
// An earlier touch/move based bonus is disabled in this version.
double cacla::getRewardPush(gsl_vector *targetv){
	position goal = sim->getObjectStartPosition(targetv);
	goal.y += 0.02;

	position handPos = sim->world_get_rhand();

	return -distance_euclidean(goal, handPos) * 2.0 ;
}
// Reward for the "point" action: compares the current right-arm vector with a
// fixed reference pose chosen by the active element of targetv.
//
// Both vectors are passed through sim->scale_arm, and the result is minus half
// the sum of absolute component differences. While the hand reports a touch
// the reward is -1.0 regardless of the pose.
//
// Prints an error (and carries on) when the absolute sum of targetv exceeds 1.
double cacla::getRewardPoint(gsl_vector *targetv){
	// Reference arm vectors, one row per targetv index (unscaled values).
	static const double kPointingPose[3][4] = {
		{ 10, 0,  10, 106 },
		{ 10, 0, -15, 106 },
		{ 10, 0, -34, 106 },
	};

	gsl_vector *joints = sim->arm_get(PART_RIGHT_ARM, 4);
	gsl_vector *pose = gsl_vector_calloc(4);

	if(gsl_blas_dasum(targetv) > 1){
		printf("ERROR: invalid target vector RewardFunction");
	}

	// If several elements are set, the highest index wins; if none is set the
	// pose stays all zeros.
	for(int slot = 0; slot < 3; ++slot){
		if(targetv->data[slot] == 1){
			for(int j = 0; j < 4; ++j){
				pose->data[j] = kPointingPose[slot][j];
			}
		}
	}

	sim->scale_arm(joints);
	sim->scale_arm(pose);

	// pose now holds (reference - current); half its L1 norm is the error.
	gsl_vector_sub(pose, joints);
	double diff = gsl_blas_dasum(pose) / 2.0;

	gsl_vector_free(joints);
	gsl_vector_free(pose);

	if(sim->rhand_touch()){
		return -1.0; // original note: touching means the hand is also moving something
	}
	return -diff;
}

// Training driver: loads the saved actor/critic networks, runs 300 learning
// episodes with a random target and action each, and saves the networks every
// 10 episodes and once more at the end.
//
// Files:
//   ./m2/rewardc.txt (append)  one line per episode: action, target index,
//                              cumulative reward
//   ./critic_input.txt, ./critic_desired.txt,
//   ./actor_input.txt,  ./actor_desired.txt (truncate)  written by learnMove
//
// If any of these files cannot be opened, an error is printed, whatever was
// opened is closed again and the function returns without training.
void cacla::learnActions(){

	FILE *fo = NULL;
	if(fopen_s(&fo, "./m2/rewardc.txt", "a") != 0){
		fo = NULL;
	}

	actor->loadNetwork("./cacla/actor/");
	critic->loadNetwork("./cacla/critic/");

	if(fopen_s(&f, "./critic_input.txt", "w") != 0){
		f = NULL;
	}
	if(fopen_s(&fdes, "./critic_desired.txt", "w") != 0){
		fdes = NULL;
	}
	if(fopen_s(&fActor, "./actor_input.txt", "w") != 0){
		fActor = NULL;
	}
	if(fopen_s(&fActorDes, "./actor_desired.txt", "w") != 0){
		fActorDes = NULL;
	}

	// Writing to a null stream would crash inside learnMove / fprintf, so stop
	// here instead of starting a run that cannot log.
	if(fo == NULL || f == NULL || fdes == NULL || fActor == NULL || fActorDes == NULL){
		printf("ERROR: learnActions could not open its log files\n");
		if(fo != NULL){ fclose(fo); }
		if(f != NULL){ fclose(f); f = NULL; }
		if(fdes != NULL){ fclose(fdes); fdes = NULL; }
		if(fActor != NULL){ fclose(fActor); fActor = NULL; }
		if(fActorDes != NULL){ fclose(fActorDes); fActorDes = NULL; }
		return;
	}

	gsl_vector *targetv = gsl_vector_calloc(3);
	gsl_vector *actionv = gsl_vector_calloc(3);

	int steps = 0;
	int iter = 300;  // original note: 100 episodes take about 30 min

	double lastReward = 0.0;
	int action;
	int tposition;

	while(iter > 0){
		steps = 75;

		// Random target position.
		gsl_vector_set_zero(targetv);
		tposition = rand() % 3;
		targetv->data[tposition] = 1;

		// Random action, with ACTION_PUSH forced in one third of the episodes
		// on top of its share of the rand() % 3 draw.
		gsl_vector_set_zero(actionv);
		if(rand() % 3 == 0){
			action = ACTION_PUSH;
		}
		else{
			action = rand() % 3;
		}
		actionv->data[action] = 1;

		sim->unblockArm();
		sim->createObjects();

		printf("%d COMMAND: %s %s \n", iter, toStringAction(getOneHotActive(actionv)), toStringPosition(getOneHotActive(targetv)));

		// Result unused; the call is kept in case the simulator relies on it
		// (TODO confirm getObjectPosition is side-effect free, then remove).
		sim->getObjectPosition(targetv);

		lastReward = 0.0;
		double cumulReward = 0.0;
		while(steps > 0){
			lastReward = learnMove(targetv, actionv);
			cumulReward += lastReward;
			steps--;
			// Cut the episode down to 7 remaining steps once something was
			// disturbed: a reward of magnitude 0.5, a non-target object moved,
			// or the target moved during an action other than push.
			if(((abs(lastReward) == 0.5) || sim->movedNotTarget(targetv) || (sim->movedTarget(targetv) && !(actionv->data[ACTION_PUSH] == 1))) && (steps > 7)){
				steps = 7;
			}
		}
		if(iter % 10 == 0)
		{
			actor->saveNetwork("./cacla/actor/");
			critic->saveNetwork("./cacla/critic/");
			printf("...saved\n");
		}
		fprintf(fo, "%d %d %f \n", action, tposition, cumulReward);
		iter--;
	}

	gsl_vector_free(targetv);
	gsl_vector_free(actionv);
	actor->saveNetwork("./cacla/actor/");
	critic->saveNetwork("./cacla/critic/");

	fclose(fo);

	// Null the shared stream pointers so nothing can write to a closed stream.
	fclose(f);
	f = NULL;
	fclose(fdes);
	fdes = NULL;

	fclose(fActor);
	fActor = NULL;
	fclose(fActorDes);
	fActorDes = NULL;
}


// Evaluation driver: loads the saved actor/critic networks and runs the actor
// (no learning) for 100 steps on each of the 9 action/target combinations,
// cycling through the three targets for action 0, then action 1, then action 2.
//
// ./m2/reward.txt is created/truncated, but nothing is written to it in this
// version (the reward logging is disabled).
void cacla::performActions(){

	FILE *fo = NULL;
	if(fopen_s(&fo, "./m2/reward.txt", "w") != 0){
		fo = NULL; // not fatal: the file is not written to below
	}

	actor->loadNetwork("./cacla/actor/");
	critic->loadNetwork("./cacla/critic/");

	gsl_vector *targetv = gsl_vector_calloc(3);
	gsl_vector *actionv = gsl_vector_calloc(3);

	int steps = 0;
	int iter = 9;  // 3 actions x 3 targets

	int action = 0;
	int target = 0;

	while(iter > 0){
		steps = 100;

		gsl_vector_set_zero(targetv);
		targetv->data[target] = 1;

		gsl_vector_set_zero(actionv);
		actionv->data[action] = 1;

		// Advance to the next combination: targets cycle fastest.
		if(target == 2){
			action++;
			target = 0;
		}else{
			target++;
		}

		sim->unblockArm();
		sim->createObjects();

		printf("%d COMMAND: %s %s \n", iter, toStringAction(getOneHotActive(actionv)), toStringPosition(getOneHotActive(targetv)));

		// Result unused; the call is kept in case the simulator relies on it
		// (TODO confirm getObjectPosition is side-effect free, then remove).
		sim->getObjectPosition(targetv);

		while(steps > 0){
			makeMove(targetv, actionv);
			steps--;
		}
		iter--;
	}

	// fclose on a null stream is undefined, so only close what was opened.
	if(fo != NULL){
		fclose(fo);
	}

	gsl_vector_free(targetv);
	gsl_vector_free(actionv);
}

void cacla::generateLanguageFiles(){
	
	int maxstep = 100;

	actor->loadNetwork("./cacla/actor/");
	critic->loadNetwork("./cacla/critic/");

	gsl_vector *targetv = gsl_vector_calloc(3);
	gsl_vector *actionv = gsl_vector_calloc(3);

	gsl_matrix *input_lang = gsl_matrix_alloc(actor->neurons->data[1], 9 * maxstep); //
	gsl_matrix *desired_lang = gsl_matrix_calloc(3, 9 * maxstep); //

	unsigned long int globalsteps = 0;
	int steps = 0;
	int iter = 9;  //   100 / 30min
	double lastReward = 0.0;

	int action = 0;
	int target = 0;

	while(iter > 0){
		steps = maxstep;

		gsl_vector_set_zero(targetv);
		targetv->data[target] = 1;

		gsl_vector_set_zero(actionv);

		//actionv->data[rand() % 3] = 1;
		//actionv->data[(rand() % 2) + 1] = 1;
		actionv->data[action] = 1;

		sim->unblockArm();
		sim->createObjects();

		printf("%d COMMAND: %s %s \n", iter, toStringAction(getOneHotActive(actionv)), toStringPosition(getOneHotActive(targetv)));

		//print_vector(targetv);
		position apos = sim->getObjectPosition(targetv);
		lastReward = 0.0;
		while(steps > 0){			
			lastReward = makeMove(targetv, actionv);
	
			gsl_vector *hidden = gsl_vector_alloc_col_from_matrix(actor->activationHidden, 0);
			gsl_matrix_set_col(input_lang, globalsteps, hidden);
			gsl_matrix_set(desired_lang, 0 + action, globalsteps, 1.0);
			//gsl_matrix_set(desired_lang, 3 + target, globalsteps, 1.0);
			gsl_vector_free(hidden);
			
			globalsteps++;
			steps--;
		}

		if(target == 2){
			action++;
			target = 0;
		}else{
			target++;
		}

		iter--;
	}

	print_matrix_to_file(desired_lang, "./lang/lang_desired_3_randstart.txt");
	print_matrix_to_file(input_lang, "./lang/lang_hidden_3_randstart.txt");

	gsl_matrix_free(input_lang);
	gsl_matrix_free(desired_lang);

	gsl_vector_free(targetv);
	gsl_vector_free(actionv);

}
